# testing/assertsql.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


from __future__ import annotations

import collections
import contextlib
import itertools
import re

from .. import event
from ..engine import url
from ..engine.default import DefaultDialect
from ..schema import BaseDDLElement


class AssertRule:
    """Base class for the rules evaluated by :class:`.SQLAsserter`.

    A rule is fed observed executions through :meth:`process_statement`
    and reports its progress through three attributes:

    * ``is_consumed`` -- the rule has been fully satisfied.
    * ``errormessage`` -- a message describing a mismatch, else ``None``.
    * ``consume_statement`` -- whether processing uses up the observation;
      when False, an enclosing :class:`.EachOf` offers the same observation
      to its next rule as well.
    """

    is_consumed = False
    errormessage = None
    consume_statement = True

    def process_statement(self, execute_observed):
        """Inspect one observed execution; the base rule ignores it."""
        return None

    def no_more_statements(self):
        """Called when observations run out before this rule is consumed.

        The default implementation fails the assertion.
        """
        message = (
            "All statements are complete, but pending assertion rules remain"
        )
        assert False, message


class SQLMatchRule(AssertRule):
    """Marker base class for rules that compare emitted SQL text."""


class CursorSQL(SQLMatchRule):
    """Match the exact SQL string and parameters sent to the DBAPI cursor.

    No recompilation takes place. The first cursor-level statement still
    pending for the observed execution is compared as-is.

    :param statement: expected cursor SQL string.
    :param params: expected cursor parameters; ``None`` skips that check.
    :param consume_statement: when False, a successful match pops only the
     cursor statement it matched. Any further cursor statements of the same
     execution stay available to the following rule. The flag is forced
     back to True on a mismatch, or once no cursor statements remain.
    """

    def __init__(self, statement, params=None, consume_statement=True):
        self.consume_statement = consume_statement
        self.params = params
        self.statement = statement

    def process_statement(self, execute_observed):
        cursor_statements = execute_observed.statements
        head = cursor_statements[0]

        mismatch = self.statement != head.statement or (
            self.params is not None and self.params != head.parameters
        )
        if mismatch:
            self.consume_statement = True
            details = (
                self.statement,
                self.params,
                head.statement,
                head.parameters,
            )
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s"
                % details
            )
            return

        # matched: drop the cursor statement we just accounted for
        cursor_statements.pop(0)
        self.is_consumed = True
        if not cursor_statements:
            self.consume_statement = True


class CompiledSQL(SQLMatchRule):
    """Assert that an execution re-compiles to an expected SQL string.

    The observed clause element is compiled again against the dialect
    returned by :meth:`_compile_dialect`. Newlines and tabs are stripped
    from both that string and ``statement``, and the two are compared with
    ``==``. If ``params`` is given, the expected parameter dicts are also
    checked against the received parameter sets (see
    :meth:`process_statement`).
    """

    def __init__(
        self, statement, params=None, dialect="default", enable_returning=True
    ):
        # statement: expected SQL text
        # params: None, a parameter dict, a list of them, or a callable
        #   taking the execution context and returning either (see
        #   _all_params)
        # dialect: "default" for a DefaultDialect configured in
        #   _compile_dialect(), otherwise a driver name for url.URL.create()
        # enable_returning: only consulted for the "default" dialect
        self.statement = statement
        self.params = params
        self.dialect = dialect
        self.enable_returning = enable_returning

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the received SQL equals the expected statement.

        ``received_statement`` has already had newlines and tabs removed by
        :meth:`_received_statement`. The same is done to ``self.statement``
        here. ``execute_observed`` is unused in this implementation, but
        subclasses receive it.
        """
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to re-compile the statement."""
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True

            if self.enable_returning:
                # advertise RETURNING / multi-values support so that the
                # compiled output looks like a RETURNING-capable dialect
                dialect.insert_returning = dialect.update_returning = (
                    dialect.delete_returning
                ) = True
                dialect.use_insertmanyvalues = True
                dialect.supports_multivalues_insert = True
                dialect.update_returning_multifrom = True
                dialect.delete_returning_multifrom = True
                # dialect.favor_returning_over_lastrowid = True
                # dialect.insert_null_pk_still_autoincrements = True

                # this is calculated but we need it to be True for this
                # to look like all the current RETURNING dialects
                assert dialect.insert_executemany_returning

            return dialect
        else:
            # any other value is treated as a URL driver name and the
            # corresponding dialect class is instantiated with no arguments
            return url.URL.create(self.dialect).get_dialect()()

    def _received_statement(self, execute_observed):
        """Reconstruct the statement and parameters for the target dialect.

        The target dialect is the one returned by :meth:`_compile_dialect`.
        That is a configured DefaultDialect unless ``dialect`` names another
        driver.

        Returns a tuple ``(statement, parameters)``:

        * ``statement`` is the compiled SQL string with newlines and tabs
          removed.
        * ``parameters`` is a list with one constructed parameter mapping
          per received parameter set, or a single mapping when no parameters
          were received.
        """

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile().  we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        # honor any schema_translate_map execution option so that the
        # re-compiled SQL uses the same translated schema names
        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        if isinstance(execute_observed.clauseelement, BaseDDLElement):
            # DDL is compiled without the DML-specific compile options
            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            # reuse column_keys / for_executemany from the real compilation
            # so the re-compiled form matches what was actually executed
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", str(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        """Check one observed execution against this rule.

        On a match, sets ``is_consumed`` and clears ``errormessage``.
        Otherwise sets ``errormessage`` to a description of the expected
        and received SQL and parameters.
        """
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                # Partial, order-insensitive parameter matching. Each
                # expected dict must be a key/value subset of some received
                # parameter set that has not matched yet. Matching is greedy:
                # the first fitting received set is taken, with no
                # backtracking. Each received set satisfies at most one
                # expected dict, and a leftover on either side means
                # failure.
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(
                execute_observed, params
            ) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or None if unset.

        A callable ``params`` is first invoked with the execution context.
        A non-list result is wrapped in a one-element list. A falsy
        ``params`` (e.g. ``None`` or an empty dict) returns None, which
        means parameters are not checked.
        """
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, execute_observed, expected_params):
        """Return a %-template for the failure message.

        Literal ``%`` in the expected values is doubled so that the later
        ``% {"received_statement": ..., "received_parameters": ...}``
        substitution in :meth:`process_statement` leaves it intact.
        """
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )


class RegexSQL(CompiledSQL):
    """Like :class:`.CompiledSQL`, but match the SQL with a regular expression.

    ``regex`` is applied with ``Pattern.match()``, so it is anchored only at
    the start of the received statement, which has newlines and tabs
    removed. Unlike CompiledSQL, ``enable_returning`` defaults to False here.
    """

    def __init__(
        self, regex, params=None, dialect="default", enable_returning=False
    ):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        # keep the source text for failure messages
        self.orig_regex = regex
        self.params = params
        self.dialect = dialect
        self.enable_returning = enable_returning

    def _failure_message(self, execute_observed, expected_params):
        pattern_text = self.orig_regex.replace("%", "%%")
        params_text = repr(expected_params).replace("%", "%%")
        template = (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
        )
        return template % (pattern_text, params_text)

    def _compare_sql(self, execute_observed, received_statement):
        return self.regex.match(received_statement) is not None


class DialectSQL(CompiledSQL):
    """Match a statement as rendered by the dialect actually in use.

    The clause element is re-compiled with the execution context's own
    dialect. The expected ``statement`` is written with ``:name`` bind
    markers, which are rewritten into that dialect's paramstyle before the
    comparison.
    """

    def _compile_dialect(self, execute_observed):
        # compile with the dialect of the execution itself
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Return the re-compiled statement and the context's parameters.

        :raises AssertionError: if the re-compiled statement does not match.
        """
        received_stmt, _ = super()._received_statement(execute_observed)

        # TODO: why do we need this part?
        # NOTE(review): every entry of execute_observed.statements shares
        # one context, so ``real_stmt.context.statement`` is the same on
        # each iteration.  This effectively checks context.statement once
        # (when any statements exist) -- confirm whether
        # ``real_stmt.statement`` was intended.
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(
                real_stmt.context.statement, received_stmt
            ):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _dialect_adjusted_statement(self, dialect):
        """Return ``self.statement`` rewritten for ``dialect.paramstyle``.

        ``:name`` bind markers are converted to the dialect's paramstyle.
        ``::`` (PostgreSQL cast syntax) is shielded so that it is not
        mistaken for a bind marker.

        :raises NotImplementedError: for a paramstyle not handled here.
        """
        paramstyle = dialect.paramstyle
        stmt = re.sub(r"[\n\t]", "", self.statement)

        # temporarily escape out PG double colons
        stmt = stmt.replace("::", "!!")

        if paramstyle == "pyformat":
            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
        elif paramstyle == "named":
            # PEP 249 "named" style already uses :name markers, the same
            # form as the expected statement; nothing to rewrite.  (This
            # used to reach re.sub() with repl=None and raise TypeError.)
            pass
        else:
            # positional params
            if paramstyle == "qmark":
                repl = "?"
            elif paramstyle == "format":
                repl = r"%s"
            elif paramstyle.startswith("numeric"):
                counter = itertools.count(1)

                num_identifier = "$" if paramstyle == "numeric_dollar" else ":"

                def repl(m):
                    return f"{num_identifier}{next(counter)}"

            else:
                raise NotImplementedError(
                    "Don't know how to adjust statement for paramstyle %r"
                    % (paramstyle,)
                )

            stmt = re.sub(r":([\w_]+)", repl, stmt)

        # put them back
        stmt = stmt.replace("!!", "::")

        return stmt

    def _compare_sql(self, execute_observed, received_statement):
        stmt = self._dialect_adjusted_statement(
            execute_observed.context.dialect
        )
        return received_statement == stmt

    def _failure_message(self, execute_observed, expected_params):
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self._dialect_adjusted_statement(
                    execute_observed.context.dialect
                ).replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )


class CountStatements(AssertRule):
    """Assert the total number of observed executions.

    The rule never marks itself consumed, so it absorbs every remaining
    observation. The tally is checked in :meth:`no_more_statements`.
    """

    def __init__(self, count):
        self._statement_count = 0
        self.count = count

    def process_statement(self, execute_observed):
        self._statement_count = self._statement_count + 1

    def no_more_statements(self):
        observed = self._statement_count
        if self.count == observed:
            return
        assert False, "desired statement count %d does not match %d" % (
            self.count,
            observed,
        )


class AllOf(AssertRule):
    """Match a group of rules in any order.

    Each observation is offered to the pending rules in arbitrary (set)
    order:

    * The first rule that consumes it is removed from the group.
    * A rule that neither consumes it nor reports an error is treated as
      still in progress and claims the observation.
    * If every rule reports an error, one of those errors is reported.

    The group is consumed once all of its rules have been consumed.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        if not self.rules:
            # An empty group (e.g. AllOf(*[])) has nothing to match; the
            # for/else below would otherwise index an empty list.  Like
            # EachOf with no rules, report the group consumed without
            # using up the observation, so an enclosing EachOf passes it
            # on to its next rule.
            self.is_consumed = True
            self.consume_statement = False
            return

        for rule in list(self.rules):
            # clear any error left over from an earlier observation
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            # no rule accepted the observation; surface one of the errors
            self.errormessage = list(self.rules)[0].errormessage


class EachOf(AssertRule):
    """Match a sequence of rules strictly in order.

    Each observation is offered to the first pending rule, and a consumed
    rule is dropped from the front. A rule whose ``consume_statement`` is
    False lets the same observation flow on to the next rule. The sequence
    is consumed once no rules remain.
    """

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        pending = self.rules
        if not pending:
            # nothing left to match: don't claim this observation
            self.is_consumed = True
            self.consume_statement = False

        while pending:
            current = pending[0]
            current.process_statement(execute_observed)
            if current.is_consumed:
                del pending[0]
            elif current.errormessage:
                self.errormessage = current.errormessage
            if current.consume_statement:
                break

        if not pending:
            self.is_consumed = True

    def no_more_statements(self):
        if not self.rules:
            return
        head = self.rules[0]
        if head.is_consumed:
            super().no_more_statements()
        else:
            head.no_more_statements()


class Conditional(EachOf):
    """An :class:`.EachOf` over ``rules`` if ``condition`` is truthy,
    otherwise over ``else_rules``."""

    def __init__(self, condition, rules, else_rules):
        chosen = rules if condition else else_rules
        super().__init__(*chosen)


class Or(AllOf):
    """Be satisfied as soon as any one of the given rules is consumed."""

    def process_statement(self, execute_observed):
        for candidate in self.rules:
            candidate.process_statement(execute_observed)
            if candidate.is_consumed:
                self.is_consumed = True
                return
        # none matched: report the error of one of the alternatives
        self.errormessage = list(self.rules)[0].errormessage


class SQLExecuteObserved:
    """One statement execution as recorded by :func:`assert_engine`.

    Holds the execution context, the clause element and parameters captured
    at ``before_execute`` time, and (in ``statements``) every cursor-level
    statement subsequently emitted under that same context.
    """

    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        # normalize to a list of parameter sets: non-empty multiparams win,
        # then a single non-empty params, else no parameters at all
        if multiparams:
            parameters = multiparams
        else:
            parameters = [params] if params else []
        self.parameters = parameters
        self.statements = []

    def __repr__(self):
        return f"{self.statements}"


class SQLCursorExecuteObserved(
    collections.namedtuple(
        typename="SQLCursorExecuteObserved",
        field_names=("statement", "parameters", "context", "executemany"),
    )
):
    """A single cursor-level execution as seen by ``after_cursor_execute``."""


class SQLAsserter:
    """Collects observed executions and checks them against rules.

    While recording, observations are appended to ``accumulated``.
    :meth:`_close` freezes them into ``_final``, after which :meth:`assert_`
    may be called.
    """

    def __init__(self):
        self.accumulated = []

    def _close(self):
        frozen = self.accumulated
        del self.accumulated
        self._final = frozen

    def assert_(self, *rules):
        """Assert that the recorded executions satisfy ``rules`` in order."""
        sequence = EachOf(*rules)

        remaining = list(self._final)
        while remaining:
            sequence.process_statement(remaining.pop(0))
            if sequence.is_consumed:
                break
            assert not sequence.errormessage, sequence.errormessage

        if remaining:
            assert False, "Additional SQL statements remain:\n%s" % remaining
        elif not sequence.is_consumed:
            sequence.no_more_statements()


@contextlib.contextmanager
def assert_engine(engine):
    """Record the SQL emitted on ``engine`` for the duration of the block.

    Yields a :class:`.SQLAsserter`. While the block runs, two listeners
    are active:

    * a ``before_execute`` listener remembers the most recent clause
      element and parameters;
    * an ``after_cursor_execute`` listener groups cursor-level statements
      by execution context into :class:`.SQLExecuteObserved` records.

    On exit both listeners are removed and the asserter is closed, so that
    ``asserter.assert_()`` can be called.
    """
    asserter = SQLAsserter()

    # (clauseelement, multiparams, params) from the latest before_execute
    last_call = []

    def connection_execute(
        conn, clauseelement, multiparams, params, execution_options
    ):
        # grab the original statement + params before any cursor execution
        last_call[:] = clauseelement, multiparams, params

    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return
        accumulated = asserter.accumulated
        # consecutive cursor statements issued under the same execution
        # context are collected into a single observation
        if accumulated and accumulated[-1].context is context:
            observation = accumulated[-1]
        else:
            observation = SQLExecuteObserved(
                context, last_call[0], last_call[1], last_call[2]
            )
            accumulated.append(observation)

        observation.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    event.listen(engine, "before_execute", connection_execute)
    event.listen(engine, "after_cursor_execute", cursor_execute)
    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
